{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "46f5f5b2-f7f9-415f-8049-3e2f66d38195",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 第四章 在 LangChain 中调用 OpenAI Function Calling\n",
    "\n",
    " - [一、设置OpenAI API Key](#一、设置OpenAI-API-Key)\n",
    " - [二、Pydantic语法](#二、Pydantic语法)\n",
    "     - [2.1 简单创建Python类](#2.1-简单创建Python类)\n",
    "     - [2.2 使用Pydantic创建类](#2.2-使用Pydantic创建类)\n",
    " - [三、 基于Pydantic的OpenAI函数定义](#三、基于Pydantic的OpenAI函数定义)\n",
    "     - [3.1 构造OpenAI的function](#3.1-构造OpenAI的function)\n",
    "     - [3.2 通过langchain使用function](#3.2-通过langchain使用function)\n",
    "     - [3.3 强制使用function](#3.3-强制使用function)\n",
    " - [四、使用function](#四、使用function)\n",
    "     - [4.1 链式使用](#3.1-链式使用)\n",
    "     - [4.2 使用多个function](#4.2-使用多个function)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1eb87015",
   "metadata": {},
   "source": [
    "# 一、设置OpenAI-API-Key\n",
    "\n",
    "详细内容见`设置OpenAI_API_KEY.ipynb`文件"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ea4686a",
   "metadata": {},
   "source": [
    "# 二、Pydantic语法"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46645a3f",
   "metadata": {},
   "source": [
    "Pydantic 是 python 的“数据验证库”。\n",
    "- 与 python 类型注释一起工作。但是，与静态类型检查不同，它们在运行时被积极地用于数据验证和转换。\n",
    "- 提供内置方法来序列化/反序列化模型到 JSON ，字典等。\n",
    "- LangChain 利用 Pydantic 创建 JsONScheme 描述函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56cbff25",
   "metadata": {},
   "source": [
    "## 2.1 简单创建Python类"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "188a17b3",
   "metadata": {},
   "source": [
    "在标准python中，可以创建一个`User`类，拥有`name`、`age`、`email`三种属性值。直接进行创建，不能对字段类型进行约束，即年龄中可能传入不合规的字符串类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae8e977",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from pydantic import BaseModel, Field"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9447a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建User类\n",
    "class User:\n",
    "    def __init__(self, name: str, age: int, email: str):\n",
    "        self.name = name\n",
    "        self.age = age\n",
    "        self.email = email"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e765e2b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Joe\n"
     ]
    }
   ],
   "source": [
    "# 生成user对象\n",
    "foo = User(name=\"Joe\",age=32, email=\"joe@gmail.com\")\n",
    "\n",
    "# 输出foo中 name字段\n",
    "print(foo.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa33be60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bar\n"
     ]
    }
   ],
   "source": [
    "# 生成user对象\n",
    "foo = User(name=\"Joe\",age=\"bar\", email=\"joe@gmail.com\")\n",
    "\n",
    "# 输出foo中 name字段\n",
    "# name字段填写的是字符串类型，但能够创建成功并输出\n",
    "print(foo.age)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccbfc8d5",
   "metadata": {},
   "source": [
    "## 2.2 使用 Pydantic 创建类"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80a59b93",
   "metadata": {},
   "source": [
    "使用 Pydantic 创建类，可以对类的属性值进行格式约束。在创建类的时候会进行格式验证，如果格式不符合要求会报错。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49675d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用 Pydantic创建pUser类别，说明age使用int类型\n",
    "class pUser(BaseModel):\n",
    "    name: str\n",
    "    age: int\n",
    "    email: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "825b261e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Jane\n"
     ]
    }
   ],
   "source": [
    "# 生成user对象\n",
    "foo_p = pUser(name=\"Jane\", age=32, email=\"jane@gmail.com\")\n",
    "\n",
    "# 输出foo中 name字段\n",
    "print(foo_p.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41a635f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 validation error for pUser\n",
      "age\n",
      "  value is not a valid integer (type=type_error.integer)\n"
     ]
    }
   ],
   "source": [
    "# 使用了pydantic，由于age不是int，因此会报错。输出报错内容\n",
    "try:\n",
    "    foo_p = pUser(name=\"Jane\", age=\"bar\", email=\"jane@gmail.com\")\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e7ec9aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Class(BaseModel):\n",
    "    students: List[pUser]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3206e1fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Class(students=[pUser(name='Jane', age=32, email='jane@gmail.com')])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "obj = Class(\n",
    "    students=[pUser(name=\"Jane\", age=32, email=\"jane@gmail.com\")]\n",
    ")\n",
    "obj"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f956e561",
   "metadata": {},
   "source": [
    "# 三、基于Pydantic的OpenAI函数定义"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06f67f9a",
   "metadata": {},
   "source": [
    "## 3.1 构造 OpenAI 的 function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8765afc9",
   "metadata": {},
   "source": [
    "我们创建了一个`WeatherSearch`类，它继承自 Pydantic 的 BaseModel 子类.因此`WeatherSearch`类的所有成员都被具备了数据类型校验机制，该类有一个str类型的成员`airport_code`表示机场代码，并有一个描述信息description，用来说明 airport_code 的作用，在 airport_code 的上方也有一段文本描述信息：输入机场代码，获取该机场的天气信息。这段文本信息是对类`WeatherSearch`的说明，意思是通过机场代码可以查询天气情况。\n",
    "\n",
    "同时，我们使用 langchain 将这个`WeatherSearch`类转换成openai的函数描述对象："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8886dc0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义WeatherSearch\n",
    "# WeatherSearch的function需要写注释完成函数的description\n",
    "class WeatherSearch(BaseModel):\n",
    "    \"\"\"输入机场代码，获取该机场的天气信息\"\"\"\n",
    "    airport_code: str = Field(description=\"输入机场代码查询天气\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26e7e42f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入langchain\n",
    "from langchain.utils.openai_functions import convert_pydantic_to_openai_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "104b30c3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'WeatherSearch',\n",
       " 'description': '输入机场代码，获取该机场的天气信息',\n",
       " 'parameters': {'title': 'WeatherSearch',\n",
       "  'description': '输入机场代码，获取该机场的天气信息',\n",
       "  'type': 'object',\n",
       "  'properties': {'airport_code': {'title': 'Airport Code',\n",
       "    'description': '输入机场代码查询天气',\n",
       "    'type': 'string'}},\n",
       "  'required': ['airport_code']}}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 转化为openai function calling模式\n",
    "weather_function = convert_pydantic_to_openai_function(WeatherSearch)\n",
    "weather_function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbd9fa5c",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "接下来，使用 langchain 的`convert_pydantic_to_openai_function`方法将 Pydantic 类转换成了 openai 的函数描述对象。需要的注意的是在定义Pydantic类时文本描述信息不可缺少，如缺少文本描述信息会导致转换时出错。\n",
    "\n",
    "- `WeatherSearch1`，由于我们没有在`WeatherSearch1`中加入对本身的描述信息，导致在转换时报错。\n",
    "- `WeatherSearch2`加入对本身的描述信息，因此不会报错。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa2584c8",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'description'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)\n",
      "Cell \u001b[0;32mIn[15], line 5\u001b[0m\n",
      "\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mWeatherSearch1\u001b[39;00m(BaseModel):\n",
      "\u001b[1;32m      3\u001b[0m     airport_code: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m Field(description\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m输入机场代码查询天气\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;32m----> 5\u001b[0m convert_pydantic_to_openai_function(WeatherSearch1)\n",
      "\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.11/site-packages/langchain/utils/openai_functions.py:28\u001b[0m, in \u001b[0;36mconvert_pydantic_to_openai_function\u001b[0;34m(model, name, description)\u001b[0m\n",
      "\u001b[1;32m     24\u001b[0m schema \u001b[38;5;241m=\u001b[39m dereference_refs(model\u001b[38;5;241m.\u001b[39mschema())\n",
      "\u001b[1;32m     25\u001b[0m schema\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdefinitions\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
      "\u001b[1;32m     26\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n",
      "\u001b[1;32m     27\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: name \u001b[38;5;129;01mor\u001b[39;00m schema[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtitle\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n",
      "\u001b[0;32m---> 28\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdescription\u001b[39m\u001b[38;5;124m\"\u001b[39m: description \u001b[38;5;129;01mor\u001b[39;00m schema[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdescription\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n",
      "\u001b[1;32m     29\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameters\u001b[39m\u001b[38;5;124m\"\u001b[39m: schema,\n",
      "\u001b[1;32m     30\u001b[0m }\n",
      "\n",
      "\u001b[0;31mKeyError\u001b[0m: 'description'"
     ]
    }
   ],
   "source": [
    "# WeatherSearch1没有写注释，会报错\n",
    "class WeatherSearch1(BaseModel):\n",
    "    airport_code: str = Field(description=\"输入机场代码查询天气\")\n",
    "\n",
    "convert_pydantic_to_openai_function(WeatherSearch1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92e371eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构造WeatherSearch2，不对参数注释\n",
    "class WeatherSearch2(BaseModel):\n",
    "    \"\"\"输入机场代码，获取该机场的天气信息\"\"\"\n",
    "    airport_code: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b710902b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'WeatherSearch2',\n",
       " 'description': '输入机场代码，获取该机场的天气信息',\n",
       " 'parameters': {'title': 'WeatherSearch2',\n",
       "  'description': '输入机场代码，获取该机场的天气信息',\n",
       "  'type': 'object',\n",
       "  'properties': {'airport_code': {'title': 'Airport Code', 'type': 'string'}},\n",
       "  'required': ['airport_code']}}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "convert_pydantic_to_openai_function(WeatherSearch2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35ff3121",
   "metadata": {},
   "source": [
    "## 3.2 通过langchain使用function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0792e86e",
   "metadata": {},
   "source": [
    "为了实现 LLM 对 function 的调用，有一下两种方式\n",
    "\n",
    "- 在`invoke`中指定 functions\n",
    "- 执行`invoke`之前使用`bind`方法来绑定函数描述对象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cea24c65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入ChatOpenAI\n",
    "from langchain.chat_models import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ed0850",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ChatOpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "550029a0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n  \"airport_code\": \"SFO\"\\n}'}}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 在invoke中使用openai function功能\n",
    "model.invoke(\"今天旧金山的天气怎么样？\", functions=[weather_function])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458e21ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n  \"airport_code\": \"SFO\"\\n}'}}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 直接在构造模型中声明functions，对话时不需要在声明\n",
    "model_with_function = model.bind(functions=[weather_function])\n",
    "model_with_function.invoke(\"今天旧金山的天气怎么样？\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c6f6aed",
   "metadata": {},
   "source": [
    "## 3.3 强制使用function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dfdcad6",
   "metadata": {},
   "source": [
    "如果想要强制使用 function，需要在`bind`中增加`function_call`参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddc08c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 强制使用function，在模型构建时声明function_call\n",
    "model_with_forced_function = model.bind(functions=[weather_function], function_call={\"name\":\"WeatherSearch\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11bb59da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n  \"airport_code\": \"SFO\"\\n}'}}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 基于已经声明的function对话，能够调用function\n",
    "model_with_forced_function.invoke(\"旧金山的天气怎么样？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcce9a68",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n  \"airport_code\": \"LAX\"\\n}'}}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 输入hi，强制使用function的模型，即使prompt与函数description无关也会命中\n",
    "model_with_forced_function.invoke(\"你好!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f805a67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Hello! How can I assist you today?', additional_kwargs={}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 输入hi，未强制使用function的模型，prompt与函数description无关不会命中\n",
    "model_with_function.invoke(\"你好!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8c131cf",
   "metadata": {},
   "source": [
    "# 四、使用function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6002ec0d",
   "metadata": {},
   "source": [
    "在一般情况下我们会使用 chain 来实现整个问答的流程，接下来我们通过创建 chain 来实现函数调用功能"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43f0e6d4",
   "metadata": {},
   "source": [
    "## 4.1 链式使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d7d7b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 引入需要的环境\n",
    "from langchain.prompts import ChatPromptTemplate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17df129f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用预定义模版创建聊天代码\n",
    "prompt = ChatPromptTemplate.from_messages([\n",
    "    (\"system\", \"你是一个乐于助人的助手\"),\n",
    "    (\"user\", \"{input}\")\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c70e2192",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用prompt + model_with_function 组合\n",
    "chain = prompt | model_with_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d5c71f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n  \"airport_code\": \"SFO\"\\n}'}}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 创建命中function的问答\n",
    "chain.invoke({\"input\": \"今天旧金山的天气怎么样？\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a98e8dbd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Hello! How can I assist you today?', additional_kwargs={}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 创建未命中function的问答\n",
    "chain.invoke({\"input\": \"你好!\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19d7433a",
   "metadata": {},
   "source": [
    "## 4.2 使用多个function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1884ed9",
   "metadata": {},
   "source": [
    "我们可以传递一组函数，让 LLM 根据问题上下文决定使用哪个函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1cc30d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建 ArtistSearch function\n",
    "class ArtistSearch(BaseModel):\n",
    "    \"\"\"调用此命令可以获得特定艺术家的歌曲名称\"\"\"\n",
    "    artist_name: str = Field(description=\"要查的艺术家的名字\")\n",
    "    n: int = Field(description=\"number of results\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be15a438",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 组装WeatherSearch和ArtistSearch 函数\n",
    "functions = [\n",
    "    convert_pydantic_to_openai_function(WeatherSearch),\n",
    "    convert_pydantic_to_openai_function(ArtistSearch),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "747e93e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将两个function放入模型\n",
    "model_with_functions = model.bind(functions=functions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "276e8a25",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n  \"airport_code\": \"SFO\"\\n}'}}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 命中 WeatherSearch function\n",
    "model_with_functions.invoke(\"旧金山的天气怎么样？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1dcd0b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'name': 'ArtistSearch', 'arguments': '{\\n  \"artist_name\": \"taylor swift\",\\n  \"n\": 3\\n}'}}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 命中 ArtistSearch function\n",
    "model_with_functions.invoke(\"找到泰勒·斯威夫特的三首歌？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06f9e5d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Hello! How can I assist you today?', additional_kwargs={}, example=False)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 命中都未命中\n",
    "model_with_functions.invoke(\"你好!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fbd5418",
   "metadata": {},
   "source": [
    "# 五、英文版提示"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4acb9138",
   "metadata": {},
   "source": [
    "**3.1 构造OpenAI的function**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af21fad",
   "metadata": {},
   "outputs": [],
   "source": [
    "class WeatherSearch(BaseModel):\n",
    "    \"\"\"Call this with an airport code to get the weather at that airport\"\"\"\n",
    "    airport_code: str = Field(description=\"airport code to get weather for\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c13c9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class WeatherSearch1(BaseModel):\n",
    "    airport_code: str = Field(description=\"airport code to get weather for\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd9a8ed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class WeatherSearch2(BaseModel):\n",
    "    \"\"\"Call this with an airport code to get the weather at that airport\"\"\"\n",
    "    airport_code: str"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f54d4bd2",
   "metadata": {},
   "source": [
    "**4.1 链式使用**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2b91dd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate.from_messages([\n",
    "    (\"system\", \"You are a helpful assistant\"),\n",
    "    (\"user\", \"{input}\")\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "259599c1",
   "metadata": {},
   "source": [
    "**4.2 使用多个function**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfe327f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ArtistSearch(BaseModel):\n",
    "    \"\"\"Call this to get the names of songs by a particular artist\"\"\"\n",
    "    artist_name: str = Field(description=\"name of artist to look up\")\n",
    "    n: int = Field(description=\"number of results\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
